
import torch
import torch.nn as nn
import math
class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding followed by HGR token grouping.

    A strided convolution cuts the image into non-overlapping
    ``patch_size x patch_size`` patches and projects each to ``embed_dim``.
    The patch tokens are optionally normalized and then passed through
    ``HGR_atten``, which merges every ``patch_size**2`` consecutive tokens
    into one output token.

    Args:
        img_size: side length of the (square) input image.
        patch_size: side length of each square patch.
        in_c: number of input image channels.
        embed_dim: width of each token.
        num_heads: head count forwarded to ``HGR_atten``.
        norm_layer: optional callable; ``norm_layer(embed_dim)`` builds the
            token normalization. Identity when not given.
    """
    def __init__(self, img_size=224, patch_size=4, in_c=3, embed_dim=192, num_heads=8, norm_layer=None):
        super().__init__()
        self.img_size = (img_size,) * 2
        self.patch_size = (patch_size,) * 2
        # Each side is divided by patch_size**2 rather than patch_size. That
        # makes num_patches equal to the number of tokens forward() returns
        # after HGR_atten groups patch_size**2 patches into one token (when
        # img_size is divisible by patch_size**2).
        self.grid_size = tuple(side // p ** 2 for side, p in zip(self.img_size, self.patch_size))
        rows, cols = self.grid_size
        self.num_patches = rows * cols

        # Submodules are registered in the same order as before so parameter
        # initialization order and state_dict keys are unchanged.
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
        self.norm = nn.Identity() if not norm_layer else norm_layer(embed_dim)
        self.HGR_atten = HGR_atten(dim=embed_dim, qk_bias=False, num_heads=num_heads, patch_size=self.patch_size[0])

    def forward(self, x):
        """Embed images ``x`` of shape [B, in_c, img_size, img_size] into tokens.

        Raises AssertionError when the spatial size differs from ``img_size``.
        """
        _, _, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

        # conv: [B, C, H, W] -> [B, embed_dim, H/p, W/p]
        # flatten + transpose: -> [B, (H/p)*(W/p), embed_dim]
        patches = self.proj(x)
        tokens = patches.flatten(2).transpose(1, 2)
        tokens = self.norm(tokens)

        # HGR grouping of consecutive patch tokens
        return self.HGR_atten(tokens)

class Attention(nn.Module):
    """Multi-head self-attention over a token sequence.

    Args:
        dim: token width ``C``; split evenly across ``num_heads``.
        num_heads: number of attention heads.
        qkv_bias: whether the fused Q/K/V projection has a bias.
        qk_scale: override for the logit scale. When None (or any falsy
            value), ``(dim // num_heads) ** -0.5`` is used.
        attn_drop_ratio: dropout probability applied to the attention weights.
        proj_drop_ratio: dropout probability applied after the output projection.
    """
    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop_ratio=0.,
                 proj_drop_ratio=0.):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        per_head = dim // num_heads
        self.scale = qk_scale or per_head ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop_ratio)
        # NOTE(review): forward() never uses this submodule. It is still
        # registered as a child module (with PatchEmbed's defaults, regardless
        # of `dim`), so its parameters show up in parameters()/state_dict().
        # It is kept only so existing checkpoints keep loading; consider
        # removing it.
        self.PatchEmbed = PatchEmbed()

    def forward(self, x):
        """Apply self-attention to ``x`` of shape [B, N, C] (C must equal ``dim``).

        Returns a tensor of the same shape [B, N, C].
        """
        batch, tokens, width = x.shape
        head_dim = width // self.num_heads

        # qkv(): [B, N, 3C] -> [B, N, 3, heads, head_dim]
        # permute: -> [3, B, heads, N, head_dim]
        fused = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        # Scaled dot-product logits: [B, heads, N, N]
        logits = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(torch.softmax(logits, dim=-1))

        # [B, heads, N, head_dim] -> [B, N, heads, head_dim] -> [B, N, C]
        mixed = (weights @ v).transpose(1, 2).reshape(batch, tokens, width)
        return self.proj_drop(self.proj(mixed))


class HGR_atten(nn.Module):
    """HGR-score token mixer applied to groups of ``patch_size**2`` tokens.

    The input sequence [B, L, C] is cut into ``L // patch_size**2``
    consecutive groups of ``patch_size**2`` tokens. Each group is laid out as a
    ``patch_size x patch_size`` grid, projected to per-head query/key
    features, and scored with ``HGRscore3``. The per-head scores are
    concatenated back to width ``C``.

    Args:
        dim: token width ``C``; must be divisible by ``num_heads``.
        qk_bias: whether the fused query/key projection has a bias.
        num_heads: number of heads the channel dim is split into.
        embed_dim: stored on the module but not used by this class.
        patch_size: side length ``P`` of each token group (``P*P`` tokens).
    """
    def __init__(self,
                 dim,
                 qk_bias=False,
                 num_heads = 8,
                 embed_dim = 192,
                 patch_size = 4):
        super(HGR_atten, self).__init__()
        # Fused projection producing [query | key] along the last dim. The
        # attribute name "qkv" is kept so existing state_dict keys stay valid.
        self.qkv = nn.Linear(dim, dim * 2, bias=qk_bias)
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.patch_size = patch_size

    def _split_qk(self, x):
        """Group tokens of ``x`` ([B, L, C]) and project them to per-head q and k.

        Returns ``(q, k)``, each of shape [B, heads, L // P**2, C // heads, P, P].
        A RuntimeError from reshape is raised if L is not a multiple of P**2.
        """
        p = self.patch_size
        x = x.reshape(x.shape[0], x.shape[1] // p ** 2, p, p, x.shape[2])
        B, N, H, W, C = x.shape
        # The linear output is [B, N, H, W, 2C] with the query channels first.
        # The 2-way q/k split must therefore come from the LAST dim. Reshaping
        # to (B, N, 2, H, W, ...) would slice the H*W*2C block in memory order
        # and mix query/key channels across spatial positions.
        qk = self.qkv(x).reshape(B, N, H, W, 2, self.num_heads, C // self.num_heads)
        # -> [2, B, heads, N, head_dim, H, W]
        qk = qk.permute(4, 0, 5, 1, 6, 2, 3)
        return qk[0], qk[1]

    def forward(self, x):
        """Map tokens [B, L, C] to [B, L // patch_size**2, C] HGR scores."""
        q, k = self._split_qk(x)

        attn, _, _ = HGRscore3(q, k)
        # [B, heads, N, head_dim] -> [B, N, heads, head_dim] -> [B, N, C]
        attn = attn.permute(0, 2, 1, 3)
        attn = attn.reshape(attn.shape[0], attn.shape[1], -1)

        return attn

def caldistributionCosCorr(f, g):
    """Mean row-wise cosine similarity between the off-diagonal parts of f and g.

    For each trailing ``N x N`` matrix (N = ``f.shape[-2]``), the strictly upper
    triangle is shifted into an ``(N-1) x (N-1)`` block whose entry ``(i, j)``
    with ``i <= j`` holds ``m[i, j+1]``. That block is then mirrored into its
    lower triangle, making it symmetric. Each row of the ``f`` and ``g`` blocks
    is L2-normalized, the row-wise dot products are summed, and the sum is
    divided by ``N - 1``.

    Args:
        f, g: tensors of the same shape ``[..., N, N]``. Any number (>= 0) of
            leading dims is accepted; the last two dims are assumed square.

    Returns:
        Tensor of shape ``f.shape[:-2]``. For ``N == 1`` there are no
        off-diagonal entries and the result is 0/0 (NaN).
    """
    Number2samples = f.shape[-2]

    def _offdiag_block(m):
        # Keep strictly-upper entries, drop the last row and first column,
        # then mirror the strict upper part of that block into its lower part.
        upper = torch.triu(m, diagonal=1)[..., 0:Number2samples - 1, 1:Number2samples]
        return upper + torch.triu(upper, diagonal=1).transpose(-2, -1)

    f1 = torch.nn.functional.normalize(_offdiag_block(f), dim=-1)
    g1 = torch.nn.functional.normalize(_offdiag_block(g), dim=-1)

    corr = torch.sum(torch.sum(f1 * g1, dim=-1), dim=-1) / (Number2samples - 1)
    return corr

def HGR3TraceLoss(f, g, expected_relevance=1):
    """Return ``expected_relevance`` minus the similarity of f's and g's self-products.

    ``f`` and ``g`` are L2-normalized along dim 5 (which requires rank >= 6;
    for 6-D inputs this is the last dim). Each is then multiplied by its own
    transpose over the last two dims, and the two products are compared with
    ``caldistributionCosCorr``.
    """
    normalize = torch.nn.functional.normalize
    unit_f = normalize(f, dim=5)
    unit_g = normalize(g, dim=5)

    gram_f = unit_f @ unit_f.transpose(-2, -1)
    gram_g = unit_g @ unit_g.transpose(-2, -1)

    return expected_relevance - caldistributionCosCorr(gram_f, gram_g)

def HGRscore3(f, g):
    """Combine a direct and a second-order cosine correlation of ``f`` and ``g``.

    ``corr`` is the 0-dim tensor from ``calCosCorr``. ``tra`` compares the
    self-products ``f1 @ f1^T`` and ``g1 @ g1^T`` of the dim-5-normalized
    inputs via ``caldistributionCosCorr``.

    Returns:
        ``(1.5 - corr - tra / 2, corr, tra)``.
    """
    corr, unit_f, unit_g = calCosCorr(f, g)

    tra = caldistributionCosCorr(unit_f @ unit_f.transpose(-2, -1),
                                 unit_g @ unit_g.transpose(-2, -1))

    score = torch.tensor(1.5) - corr - tra / 2
    return score, corr, tra

def calCosCorr(f, g):
    """Normalize ``f`` and ``g`` along dim 5 and sum their element-wise product.

    The full sum of ``f1 * g1`` is divided by ``len(f)`` (the size of dim 0).

    Returns:
        ``(corr, f1, g1)``: the 0-dim correlation and the two normalized tensors.
    """
    unit_f = torch.nn.functional.normalize(f, dim=5)
    unit_g = torch.nn.functional.normalize(g, dim=5)

    batch = len(f)
    corr = (unit_f * unit_g).sum(1).sum() / batch
    return corr, unit_f, unit_g



if __name__ == '__main__':
    # Smoke test: embed a random batch of images, then run self-attention.
    x = torch.randn(8, 3, 224, 224)

    attention = Attention(dim=192)
    # Use a lowercase instance name. Rebinding `PatchEmbed` would shadow the
    # class, and any later `PatchEmbed()` call (e.g. inside
    # Attention.__init__) would then invoke the instance instead of building
    # a new module.
    patch_embed = PatchEmbed()

    x = patch_embed(x)  # [8, 3, 224, 224] -> [8, 196, 192]
    out = attention(x)
    print(out)

